import numpy as np
import tensorflow as tf

from time import gmtime, strftime
from sys import argv

from mtmlcv.config import Config
from mtmlcv.data_loader import Data_Loader
from mtmlcv.joint_model import JointAE
from mtmlcv.loss import Loss
from mtmlcv.plot import Plot

def _count_batches(sess, iterator, next_element):
    """Return how many batches ``iterator`` yields in one full pass.

    The iterator is re-initialized and then drained, so callers must
    re-initialize it again before reading real data.
    """
    sess.run(iterator.initializer)
    n_batches = 0
    try:
        while True:
            sess.run(next_element)
            n_batches += 1
    except tf.errors.OutOfRangeError:
        pass
    return n_batches


def _every(step, freq):
    """Return True when the 1-based ``step`` is a multiple of ``freq``.

    A falsy ``freq`` (0 or None) means "never" instead of raising
    ZeroDivisionError.
    """
    return bool(freq) and step % freq == 0


def _sample_coefficients(loss_model, ph_lr, lr, max_coeffs):
    """Draw one epoch's random loss-term weights.

    Each weight is ``np.random.rand() * max_coeffs[name]``; disabled terms
    have a maximum of 0 and therefore always get weight 0.

    Returns:
        (feed_dict, para_str): the placeholder feed for ``sess.run`` and the
        same values keyed by short name for logging.
    """
    # Draw in the original order (pc, lc, rc, dc, regc) so the NumPy random
    # stream is consumed the same way as before.
    pc = np.random.rand() * max_coeffs["pc"]
    lc = np.random.rand() * max_coeffs["lc"]
    rc = np.random.rand() * max_coeffs["rc"]
    dc = np.random.rand() * max_coeffs["dc"]
    regc = np.random.rand() * max_coeffs["regc"]

    # Feed exactly the values that are logged. Previously reconst/dist/reg
    # were hard-coded to 0 while the log reported the sampled values.
    feed_dict = {
        loss_model.reconst_coeff: rc,
        loss_model.pe_coeff: pc,
        loss_model.label_coeff: lc,
        loss_model.dist_coeff: dc,
        loss_model.reg_coeff: regc,
        ph_lr: lr,
    }
    para_str = {"lc": lc, "rc": rc, "pc": pc, "dc": dc, "regc": regc, "lr": lr}
    return feed_dict, para_str


def _log(fout, loss_model, header, loss, details, metrics, para_str):
    """Format one progress line with ``loss_model.tostr``; echo to stdout and ``fout``."""
    line = loss_model.tostr(header, loss, details, metrics, para_str)
    print(line)
    print(line, file=fout)


def main():
    """Train, evaluate and visualize a JointAE model described by a config file.

    The config path is taken from ``argv[1]`` and loaded with
    ``Config.from_file``. The function then:

    1. builds the training ``Data_Loader`` and the ``JointAE`` network;
    2. trains with Adam for ``config.num_epochs`` epochs, re-sampling the
       loss-term weights at the start of every epoch and dividing the
       learning rate by ``config.lr_ratio`` every ``config.lr_sched`` epochs;
    3. saves weights every ``config.check_freq`` epochs (and at the last
       epoch), then saves the graph;
    4. reports loss/MAE for every batch of ``config.test_data``, evaluated
       with ``training=False``;
    5. writes predictions for one batch of each set to ``<project>_*.npz``
       and plots them with ``Plot``.

    Progress lines go to stdout and to ``log.<project>``.
    """
    config = Config.from_file(argv[1])

    num_epochs = config.num_epochs
    log_freq = config.log_freq
    check_freq = config.check_freq

    lr = config.lr
    lr_ratio = config.lr_ratio
    lr_sched = config.lr_sched

    # Upper bounds for the per-epoch random loss weights. A term that is
    # switched off in the config gets a bound of 0.
    max_coeffs = {
        "pc": config.pc if config.use_pe else 0,
        "lc": config.lc if config.use_label else 0,
        "rc": config.rc if config.use_reconst else 0,
        "dc": config.dc if config.use_dist else 0,
        "regc": config.regc,
    }

    # data loader
    data = Data_Loader(
        filename=config.train_data,
        shuffle=config.shuffle,
        input_label=config.input_label,
        target_label=config.target_label,
        n_sample=config.n_sample,
        batch_size=config.batch_size,
        weight_by_pe=config.weight_by_pe,
        weight_by_label=config.weight_by_label,
    )

    # set up nn
    net = JointAE(
        model_name=config.project,
        n_features=config.n_features,
        n_latent=config.n_latent,
        encoder_arch=config.encoder_arch,
        decoder_arch=config.decoder_arch,
        dropoutrate=config.dropoutrate,
        bn=config.bn,
        use_reconst=config.use_reconst,
        use_labels=config.use_label,
        labels_arch=config.labels_arch,
        onehot=config.onehot,
        nclass=config.nclass,
        use_pe=config.use_pe,
        pe_arch=config.pe_arch,
        use_dist=config.use_dist,
        gaus_latent=config.gaus_latent,
        gaus_pe=config.gaus_pe,
        gaus_pe_latent=config.gaus_pe_latent,
        input_mean=None,
        input_std=None,
        output_mean=data.output_mean,
        output_std=data.output_std,
    )

    # The learning rate is a placeholder so the schedule can change it per run.
    ph_lr = tf.placeholder("float32", None)
    optim = tf.train.AdamOptimizer(
        learning_rate=ph_lr,
        beta1=0.9,
        beta2=0.999,
        epsilon=1e-08,
        use_locking=False,
        name="Adam",
    )

    # One scalar placeholder per loss term, so the weights can be re-sampled
    # every epoch without rebuilding the graph.
    coeffs = {k: tf.placeholder("float32", None) for k in Loss.loss_terms}
    loss_model = Loss(net, coeffs)

    # training graph
    nextx = data.iterator.get_next()
    ph_loss, ph_details, ph_pred, ph_data = loss_model.joint_loss_function(net, nextx)
    ph_metrics = loss_model.compute_MAE(ph_pred, ph_data)
    training_op = optim.minimize(ph_loss)

    init = tf.global_variables_initializer()
    sess = tf.keras.backend.get_session()
    sess.run(init)

    total_batch = _count_batches(sess, data.iterator, nextx)

    # buffering=1 -> line-buffered, so the log can be followed during training.
    with open("log.{}".format(config.project), "w+", 1) as fout:
        feed_dict, para_str = None, None
        for epoch in range(1, num_epochs + 1):
            feed_dict, para_str = _sample_coefficients(
                loss_model, ph_lr, lr, max_coeffs
            )

            sess.run(data.iterator.initializer)
            for step in range(1, total_batch + 1):
                _, loss, details, metrics = sess.run(
                    (training_op, ph_loss, ph_details, ph_metrics),
                    feed_dict=feed_dict,
                )

                if _every(step, log_freq) or step == total_batch:
                    header = "Epoch [{}/{}]".format(epoch, num_epochs)
                    header += ", Step [{}/{}]".format(step, total_batch)
                    _log(fout, loss_model, header, loss, details, metrics, para_str)

            if _every(epoch, lr_sched):
                lr = lr / lr_ratio

            if _every(epoch, check_freq) or epoch == num_epochs:
                net.save_weights(net.name, epoch)

        net.save_graph(sess)

        if feed_dict is None:
            # num_epochs == 0: nothing was sampled yet, but the test loss
            # still needs values for the coefficient placeholders.
            feed_dict, para_str = _sample_coefficients(
                loss_model, ph_lr, lr, max_coeffs
            )

        # test set, evaluated in inference mode (training=False, as in the
        # prediction graph below) so training-only behavior is switched off
        test_loader = Data_Loader(
            filename=config.test_data,
            shuffle=config.shuffle,
            input_label=config.input_label,
            target_label=config.target_label,
            n_sample=config.n_sample,
            batch_size=config.batch_size,
            weight_by_pe=config.weight_by_pe,
            weight_by_label=config.weight_by_label,
            test_only=True,
        )
        next_testx = test_loader.iterator.get_next()
        tph_loss, tph_details, tph_pred, tph_data = loss_model.joint_loss_function(
            net, next_testx, training=False
        )
        tph_metrics = loss_model.compute_MAE(tph_pred, tph_data)

        # NOTE(review): assumes the test iterator is finite (raises
        # OutOfRangeError), as the training iterator is.
        n_test = _count_batches(sess, test_loader.iterator, next_testx)
        sess.run(test_loader.iterator.initializer)
        for step in range(1, n_test + 1):
            loss, details, metrics = sess.run(
                (tph_loss, tph_details, tph_metrics), feed_dict=feed_dict
            )
            header = "*Epoch [{}/{}]".format(num_epochs, num_epochs)
            header += ", Step [{}/{}]".format(step, n_test)
            _log(fout, loss_model, header, loss, details, metrics, para_str)

    ## plot latent space
    sess.run(data.iterator.initializer)
    sess.run(test_loader.iterator.initializer)

    _, _, ph_train_pred, ph_train_data = loss_model.joint_loss_function(
        net, nextx, training=False
    )

    # Each sess.run below pulls a single batch from its freshly
    # re-initialized iterator.
    train_pred, train_batch = sess.run((ph_train_pred, ph_train_data))
    test_pred, test_batch = sess.run((tph_pred, tph_data))

    np.savez(f"{config.project}_pred.npz", **train_pred)
    np.savez(f"{config.project}_data.npz", **train_batch)
    np.savez(f"{config.project}_test_pred.npz", **test_pred)
    np.savez(f"{config.project}_test_data.npz", **test_batch)

    # NOTE(review): each call of these helpers adds new ops to the graph;
    # fine for a few plot calls, but watch graph growth if Plot calls them often.
    if net.use_pe:
        def get_pe(z):
            ph_pe = net.pe_net(z, training=False)
            return sess.run(ph_pe)
    else:
        get_pe = None

    if net.use_labels:
        def get_label(z):
            ph_label = net.label_net(z, training=False)
            return sess.run(ph_label)
    else:
        get_label = None

    intc = transform(train_batch["xyz"])
    p = Plot(
        config.project + "_train",
        train_batch,
        train_pred,
        net.n_latent,
        intc,
        save_data=False,
        svg=False,
    )
    p.plot(get_pe, get_label)

    intc = transform(test_batch["xyz"])
    p = Plot(
        config.project + "_test",
        test_batch,
        test_pred,
        net.n_latent,
        intc,
        save_data=False,
        svg=False,
    )
    p.plot(get_pe, get_label)

def transform(X):
    """Map each row of ``X`` to a 2-D coordinate used for plotting.

    Reads columns 0-4 of a 2-D array ``X`` (so ``X`` needs at least five
    columns). For every row it returns:

    * column 0: ``sqrt(x0**2 + x1**2 + 1e-7 * x4**2)``
    * column 1: ``sqrt(x3**2 + x2**2)``

    Returns an array of shape ``(n_rows, 2)``.
    """
    first_sq = X[:, 0] ** 2 + X[:, 1] ** 2 + 1.0e-7 * X[:, 4] ** 2
    second_sq = X[:, 3] ** 2 + X[:, 2] ** 2
    return np.column_stack((np.sqrt(first_sq), np.sqrt(second_sq)))

# Script entry point: main() reads the config-file path from argv[1].
if __name__ == "__main__":
    main()
